# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Providing interface methods."""
import types
from collections import OrderedDict
from functools import wraps

from mindspore import context
from mindspore import log as logger
from .tensor import Tensor as MsTensor
from .._c_expression import generate_key, Executor_, Tensor, MetaTensor, PynativeExecutor_
from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_backend
from ..parallel._ps_context import _is_role_pserver
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, _to_full_tensor, \
    _get_parameter_broadcast

# Module-level cache of the pipelines compiled through ms_function. _MindSporeFunction.compile
# stores the phase name of a compiled graph under the key returned by generate_key.
ms_compile_cache = {}

# Marker looked up in the phase name by _Executor.compile: the parameter broadcast graph is
# only built for phases that do not contain it.
BROADCAST_PHASE = "_broadcast_"


def _convert_function_arguments(fn, *args):
    """
    Process the fn default parameters.

    Args:
        fn (Function): The function to be parsed.
        args (tuple): The parameters of the function.
    """
    arguments_dict = OrderedDict()
    parse_method = None
    if isinstance(fn, (types.FunctionType, types.MethodType)):
        parse_method = fn.__name__
        index = 0
        for value in args:
            arguments_dict[f'arg{index}'] = value
            index = index + 1
        logger.debug("fn(%r) full parameters dict is: %r", fn, arguments_dict)
        converted = True
    else:
        logger.warning("Find error: fn isn't function or method")
        converted = False
    return converted, arguments_dict, parse_method


def _wrap_func(fn):
    """
    Wrapper function, convert return data to tensor or tuple of tensor.

    Args:
        fn (Function): The function need be wrapped.

    Returns:
        Function, a new function with return suitable format data.
    """

    @wraps(fn)
    def wrapper(*arg, **kwargs):
        results = fn(*arg, **kwargs)

        def _convert_data(data):
            if isinstance(data, Tensor) and not isinstance(data, MsTensor):
                return MsTensor(data)
            if isinstance(data, tuple):
                return tuple(_convert_data(x) for x in data)
            if isinstance(data, list):
                return list(_convert_data(x) for x in data)
            return data

        return _convert_data(results)

    return wrapper


def _exec_init_graph(obj, init_phase):
    """
    Execute the parameter initializer graph for the parameters that are not initialized yet.

    Args:
        obj (Object): The owner of the parameters; its `parameters_dict()` is inspected.
        init_phase (str): The phase name handed to the executor's `run_init_graph`.
    """
    graph_executor = Executor_.get_instance()
    pending_params = OrderedDict()
    for param_name, parameter in obj.parameters_dict().items():
        if parameter.is_init:
            continue
        # Collect the parameter and flag it, so that a later call does not pick it up again.
        pending_params[param_name] = parameter
        parameter.is_init = True
        parameter.data.init_flag = True

    # The init graph is only run when at least one parameter was collected.
    if pending_params:
        graph_executor.run_init_graph(pending_params, init_phase)


class _MindSporeFunction:
    """
    Represents a function compiled by mind expression.

    _MindSporeFunction will compile the original function for every combination
    of argument types and shapes it is given (as well as their values, optionally).

    Args:
        fn (Function): The root function to compile.
        input_signature (Iterable[MetaTensor]): User defined signature used to verify the inputs.
            Default: None, which skips the verification.
        obj (Object): If function is a method, obj is the owner of function,
             else, obj is none.
    """

    def __init__(self, fn, input_signature=None, obj=None):
        self.fn = fn
        self.save_graphs = context.get_context("save_graphs")
        self.save_graphs_path = context.get_context("save_graphs_path")
        self.input_signature = input_signature
        # `obj` counts as the owner of `fn` only when it has an attribute named like `fn`;
        # any other non-None `obj` is kept just to tell compiled functions apart.
        owns_fn = hasattr(obj, fn.__name__)
        self.obj = obj if owns_fn else None
        self.identify_obj = obj if (obj is not None and not owns_fn) else None
        self._executor = Executor_.get_instance()

    def build_data_init_graph(self, graph_name):
        """Build GE data graph and init graph for the given graph name."""
        if self.obj is not None:
            self._executor.build_data_graph(self.obj.parameters_dict(), graph_name,
                                            self.obj.parameters_broadcast_dict())
            # the init phase reuses the part of the graph name that starts at its first "."
            suffix_start = graph_name.find(".")
            _exec_init_graph(self.obj, "init_subgraph" + graph_name[suffix_start:])
            return

        # a plain function has no owner, hence no parameters to put into the data graph
        logger.warning("Make sure parameter should not be used in function")
        self._executor.build_data_graph(OrderedDict(), graph_name)

    def compile(self, arguments_dict, method_name):
        """
        Returns pipeline for the given args.

        Args:
            arguments_dict (OrderedDict): The call arguments keyed by their generated names.
            method_name (str): The name stored as `__parse_method__` on the function (and its owner).

        Returns:
            str, the phase name of the compiled pipeline.

        Raises:
            TypeError: If an element of `input_signature` is not a MetaTensor.
            ValueError: If the arguments are rejected by `verify_inputs_signature`.
            RuntimeError: If the executor fails to compile.
        """
        is_method = self.obj is not None
        # remove first self parameter when fn is a method
        first_input = 1 if is_method else 0
        args_list = tuple(arguments_dict.values())[first_input:]
        arg_names = tuple(arguments_dict.keys())[first_input:]

        # verify the signature for both function and method
        if self.input_signature is not None:
            signatures = list(self.input_signature)
            if any(not isinstance(sig_spec, MetaTensor) for sig_spec in signatures):
                raise TypeError("Input_signature is not MetaTensor")
            if not verify_inputs_signature(signatures, args_list):
                raise ValueError("Inputs is incompatible with input signature!")

        dic = dict(zip(arg_names, args_list))
        generate_name = self.fn.__module__ + "." + self.fn.__name__
        self.fn.__parse_method__ = method_name

        # replace key with obj info and object ext info when fn is a method
        if is_method:
            self.obj.__parse_method__ = method_name
            generate_name = self.obj.__module__ + "."
            if self.obj.__class__.__name__ != "ClipByNorm":
                generate_name += str(self.obj.create_time)
        if self.identify_obj is not None:
            generate_name += str(id(self.identify_obj))

        key = generate_key(generate_name, dic)
        if key in ms_compile_cache:
            return ms_compile_cache[key]

        phase = str(key[1]) + generate_name
        compile_target = self.obj if is_method else self.fn
        if not self._executor.compile(compile_target, args_list, phase, True):
            raise RuntimeError("Executor compile failed.")
        if context.get_context("enable_ge"):
            self.build_data_init_graph(phase)
        # since function can be redefined, we only cache class method pipeline
        if is_method or self.identify_obj is not None:
            ms_compile_cache[key] = phase
        return phase

    @_wrap_func
    def __call__(self, *args):
        init_backend()
        converted, arguments_dict, parse_method = _convert_function_arguments(self.fn, *args)
        if not converted:
            raise RuntimeError('Process function parameter is failure')

        phase = self.compile(arguments_dict, parse_method)
        if context.get_context("precompile_only"):
            return None

        # the owner (first argument of a method call) is not passed to the compiled graph
        args_list = tuple(arguments_dict.values())
        if self.obj is not None:
            args_list = args_list[1:]
        return self._executor(args_list, phase)


def ms_function(fn=None, obj=None, input_signature=None):
    """
    Create a callable MindSpore graph from a python function.

    This allows the MindSpore runtime to apply optimizations based on graph.

    Args:
        fn (Function): The Python function that will be run as a graph. Default: None.
        obj (Object): The Python Object that provides the information for identifying the compiled function.
            When the first call argument is not a Tensor and has an attribute named like `fn`, that argument
            is used instead. Default: None.
        input_signature (Tuple[MetaTensor]): The MetaTensors which describe the input arguments. They specify
            the shape and dtype of the Tensors that will be supplied to this function. If input_signature
            is specified, each input to `fn` must be a `Tensor`. And the input parameters of `fn` cannot accept
            `**kwargs`. The shape and dtype of actual inputs should keep the same as input_signature.
            Default: None.

    Returns:
        Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
        None, returns a decorator and when this decorator invokes with a single `fn` argument, the callable function is
        equal to the case when `fn` is not None.

    Raises:
        TypeError: When the returned function is called and an element of `input_signature` is not a MetaTensor.
        ValueError: When the returned function is called with inputs that do not pass the signature verification.

    Examples:
        >>> from mindspore.ops import functional as F
        ...
        >>> def tensor_add(x, y):
        ...     z = x + y
        ...     return z
        ...
        >>> @ms_function
        ... def tensor_add_with_dec(x, y):
        ...     z = x + y
        ...     return z
        ...
        >>> @ms_function(input_signature=(MetaTensor(mindspore.float32, (1, 1, 3, 3)),
        ...                               MetaTensor(mindspore.float32, (1, 1, 3, 3))))
        ... def tensor_add_with_sig(x, y):
        ...     z = x + y
        ...     return z
        ...
        >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
        >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
        ...
        >>> tensor_add_graph = ms_function(fn=tensor_add)
        >>> out = tensor_add_graph(x, y)
        >>> out = tensor_add_with_dec(x, y)
        >>> out = tensor_add_with_sig(x, y)
    """

    def wrap_mindspore(func):
        @wraps(func)
        def staging_specialize(*args):
            # A non-tensor first argument that has an attribute named like `func` is taken as the
            # owner of `func` and replaces the `obj` given to ms_function.
            first_is_owner = (bool(args)
                              and not isinstance(args[0], MsTensor)
                              and hasattr(args[0], func.__name__))
            process_obj = args[0] if first_is_owner else obj
            return _MindSporeFunction(func, input_signature, process_obj)(*args)

        return staging_specialize

    # called as `ms_function(...)` without a function: hand back the decorator itself
    return wrap_mindspore if fn is None else wrap_mindspore(fn)


def _generate_pip_args(obj, *args, method="construct"):
    """Generate arguments for pipeline."""
    if hasattr(obj, method):
        fn = getattr(obj, method)
    else:
        raise AttributeError('The process method is not exist')
    converted, arguments_dict, parse_method = _convert_function_arguments(fn, *args)
    if not converted:
        raise RuntimeError('Process method parameter is failure')
    args_list = tuple(arguments_dict.values())
    args_names = tuple(arguments_dict.keys())
    obj.__parse_method__ = parse_method
    return args_names, args_list


class _PynativeExecutor:
    """
    An pynative executor used to compile/manage/run graph.

    Every method forwards to the PynativeExecutor_ instance. Keyword arguments are forwarded
    positionally: their values are appended after the positional arguments.

    Returns:
        Graph, return the result of pipeline running.
    """

    def __init__(self):
        self._executor = PynativeExecutor_.get_instance()

    def new_graph(self, obj, *args, **kwargs):
        """Forward to `new_graph` of the backend executor."""
        inputs = args + tuple(kwargs.values())
        self._executor.new_graph(obj, *inputs)

    def end_graph(self, obj, output, *args, **kwargs):
        """Forward to `end_graph` of the backend executor."""
        inputs = args + tuple(kwargs.values())
        self._executor.end_graph(obj, output, *inputs)

    def check_graph(self, obj, *args, **kwargs):
        """Forward to `check_graph` of the backend executor and return its result."""
        inputs = args + tuple(kwargs.values())
        return self._executor.check_graph(obj, *inputs)

    def grad(self, grad, obj, weights, *args, **kwargs):
        """Forward to `grad_net` of the backend executor."""
        inputs = args + tuple(kwargs.values())
        self._executor.grad_net(grad, obj, weights, *inputs)

    def clear(self, cell_id=""):
        """Forward to `clear` of the backend executor."""
        self._executor.clear(cell_id)

    def sync(self):
        """Forward to `sync` of the backend executor."""
        self._executor.sync()

    def set_grad_flag(self, flag):
        """Forward to `set_grad_flag` of the backend executor."""
        self._executor.set_grad_flag(flag)

    def __call__(self, obj, *args, **kwargs):
        # the backend executor takes all inputs as one tuple, followed by an empty phase string
        return self._executor(obj, (*args, *kwargs.values()), "")


class _Executor:
    """
    An executor used to compile/manage/run graph.

    Including data_graph, train_graph, eval_graph and predict graph.

    Returns:
        Graph, return the result of pipeline running.
    """

    def __init__(self):
        # create needed graph by lazy mode
        self.is_init = False
        self._executor = Executor_.get_instance()
        # full phase names of the graphs that `compile` has compiled successfully
        self.compile_cache = {}

    def init_dataset(self, queue_name, dataset_size, batch_size, dataset_types, dataset_shapes,
                     input_indexs, phase='dataset'):
        """
        Initialization interface for calling data subgraph.

        Args:
            queue_name (str): The name of tdt queue on the device.
            dataset_size (int): The size of dataset.
            batch_size (int): The size of batch.
            dataset_types (list): The output types of element in dataset.
            dataset_shapes (list): The output shapes of element in dataset.
            input_indexs (list): The index of data with net.
            phase (str): The name of phase, e.g., train_dataset/eval_dataset. Default: 'dataset'.

        Returns:
            bool, specifies whether the data subgraph was initialized successfully.

        Raises:
            RuntimeError: If `init_exec_dataset` reports a failure.
        """
        if not init_exec_dataset(queue_name=queue_name,
                                 size=dataset_size,
                                 batch_size=batch_size,
                                 types=dataset_types,
                                 shapes=dataset_shapes,
                                 input_indexs=input_indexs,
                                 phase=phase):
            raise RuntimeError("Failure to init and dataset subgraph!")
        return True

    def _build_data_graph(self, obj, phase):
        """Build the data graph of `phase` from the parameters of `obj`."""
        self._executor.build_data_graph(obj.parameters_dict(), phase, obj.parameters_broadcast_dict())

    def _get_auto_split_param_names(self, parameter_layout_dict):
        """
        Collect the names of the parameters whose layout has a dimension other than -1.

        Args:
            parameter_layout_dict (dict): Maps parameter names to layout values. The element at
                index 1 of each value is scanned dimension by dimension.

        Returns:
            list, the selected parameter names in the order of `parameter_layout_dict`. A list is
            returned instead of a one-shot generator so that the result can be tested for
            emptiness and used for repeated membership tests.
        """
        return [param_name for param_name, layout in parameter_layout_dict.items()
                if any(dim != -1 for dim in layout[1])]

    def _build_broadcast_graph(self, broadcast_params_dict, broadcast_phase):
        """
        Build broadcast graph.

        Runs a `_BroadCastCell` over copies of the given parameters and writes the results back
        into the parameters with `set_data`.

        Args:
            broadcast_params_dict (dict): The parameters to broadcast, keyed by name. May be None or empty.
            broadcast_phase (str): The phase assigned to the broadcast cell.
        """
        # NOTE(review): imported locally, presumably to avoid a circular import - confirm.
        from mindspore.nn.wrap.cell_wrapper import _BroadCastCell

        if not broadcast_params_dict:
            broadcast_params_dict = {}
        broadcast_params = [Tensor(param.asnumpy()) for param in broadcast_params_dict.values()]
        _broadcast_net = _BroadCastCell(broadcast_params)
        _broadcast_net.phase = broadcast_phase
        broadcasted_params = _broadcast_net()
        for param_name, param in zip(broadcast_params_dict.keys(), broadcasted_params):
            broadcast_params_dict[param_name].set_data(param)

    def _set_dataset_mode(self, args_list):
        """set dataset mode."""
        # decide whether to sink based on whether the inputs is virtual or args_list is ()
        if (args_list and isinstance(args_list[0], Tensor) and args_list[0].virtual_flag) or \
                (args_list is not None and args_list == ()):
            _set_dataset_mode_config('sink')
        else:
            _set_dataset_mode_config('normal')

    def compile(self, obj, *args, phase='predict', do_convert=True, auto_parallel_mode=False):
        """
        Compiles graph.

        Args:
            obj (Function/Cell): The function or cell instance need compile.
            args (tuple): Function or cell input arguments.
            phase (str): The name of compile phase. Default: 'predict'.
            do_convert (bool): When set to True, convert ME graph to GE graph after compiling graph.
            auto_parallel_mode: When set to True, use auto parallel mode to compile graph.

        Return:
            Str, the full phase of the cell.
            Bool, if the graph has been compiled before, return False, else return True.

        Raises:
            ValueError: If `support_non_tensor_inputs` is set on a net that holds a GradOperation.
            RuntimeError: If the backend executor fails to compile the graph.
        """
        from mindspore import nn
        from mindspore.ops.composite import GradOperation

        # NOTE: `construct` below reads `args_names` from the enclosing scope of this method
        # (it is bound further down, before the cell is instantiated).
        class InputsToAttrCell(nn.Cell):
            """The cell that converts non-tensor inputs to attr."""

            def __init__(self, net, args_names, non_tensor_inputs):
                super(InputsToAttrCell, self).__init__()
                self.net = net
                self.args_names = args_names
                self.non_tensor_inputs = non_tensor_inputs
                self.inputs_to_attr = True

            def construct(self, *tensor_inputs):
                real_inputs = ()
                index = 0
                for i in args_names:
                    if i in self.non_tensor_inputs.keys():
                        real_inputs += (self.non_tensor_inputs[i],)
                    else:
                        real_inputs += (tensor_inputs[index],)
                        index += 1
                return self.net(*real_inputs)

        args_names, args_list = _generate_pip_args(obj, *args)
        # Always bound: the non-tensor input handling below reads it even when the phase
        # computation is skipped for cells that carry `inputs_to_attr`.
        dic = dict(zip(args_names, args_list))
        if not hasattr(obj, "inputs_to_attr"):
            key = generate_key(phase, dic)
            obj.phase_prefix = str(key[1])
            if 'export' in phase:
                phase = phase + '.' + obj.phase_prefix + '.' + str(obj.create_time)
            else:
                phase = obj.phase_prefix + phase + '.' + str(obj.create_time)

            if phase in self.compile_cache:
                logger.debug("%r graph has existed.", phase)
                return phase, False

        if getattr(obj, "support_non_tensor_inputs", None):
            for i in obj.__dict__.values():
                if isinstance(i, GradOperation):
                    raise ValueError("Not support set 'support_non_tensor_inputs' to the 'True' for grad net, "
                                     "only support forward net.")
            # split the call arguments into non-tensor attributes and real tensor inputs
            attrs = {}
            inputs = []
            for key, value in dic.items():
                if not isinstance(value, (Tensor, MetaTensor)):
                    attrs[key] = value
                else:
                    inputs.append(value)
            if attrs:
                inputs_to_attr_cell = InputsToAttrCell(obj, args_names, attrs)
                # NOTE(review): `do_convert` and `auto_parallel_mode` are not forwarded here, so
                # the wrapped cell is compiled with their defaults - confirm this is intended.
                return self.compile(inputs_to_attr_cell, *inputs, phase=phase)

        obj.check_names()
        _check_full_batch()
        self._set_dataset_mode(args_list)

        is_sink_mode = args and isinstance(args[0], Tensor) and args[0].virtual_flag
        if auto_parallel_mode and _need_to_full() and not is_sink_mode and obj.auto_parallel_compile_and_run():
            args_full = _to_full_tensor(args, _get_device_num(), _get_global_rank())
            _, args_list = _generate_pip_args(obj, *args_full)

        enable_debug_runtime = context.get_context("enable_debug_runtime")
        enable_ge = context.get_context("enable_ge")
        use_vm = not enable_ge or (enable_debug_runtime and context.get_context("mode") == context.PYNATIVE_MODE)
        result = self._executor.compile(obj, args_list, phase, use_vm)
        if not result:
            raise RuntimeError("Executor compile failed.")
        # Remember the phase only after a successful compilation; otherwise a retry after a
        # failure would be answered with "already compiled".
        self.compile_cache[phase] = phase
        graph = self._executor.get_func_graph(phase)

        if graph is None:
            logger.error("%r graph compile failed.", phase)
        if not do_convert:
            return phase, True

        if auto_parallel_mode:
            obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
        replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
        if not enable_debug_runtime or enable_ge:
            if auto_parallel_mode:
                obj.load_parameter_slice(None)

        self._updata_param_node_default_input(phase, replace)

        # set parallel inputs in sink mode
        if auto_parallel_mode and is_sink_mode:
            obj.set_parallel_input_with_inputs(*args)

        # the following GE init process is not needed when use vm or ms backend
        if enable_ge:
            self._build_data_graph(obj, phase)

            if "export" not in phase:
                init_phase = "init_subgraph" + "." + str(obj.create_time)
                _exec_init_graph(obj, init_phase)
        elif "export" in phase:
            self._build_data_graph(obj, phase)
        elif BROADCAST_PHASE not in phase and _get_parameter_broadcast():
            auto_split_param_names = []
            if auto_parallel_mode:
                auto_split_param_names = self._get_auto_split_param_names(obj.parameter_layout_dict)

            broadcast_params_dict = obj.parameters_broadcast_dict()
            if auto_split_param_names and broadcast_params_dict:
                # auto-split parameters are left out of the broadcast; a set keeps the
                # membership test repeatable and cheap
                excluded_names = set(auto_split_param_names)
                all_broadcast_params = broadcast_params_dict
                broadcast_params_dict = OrderedDict()
                for param_name, param in all_broadcast_params.items():
                    if param_name not in excluded_names:
                        broadcast_params_dict[param_name] = param
            broadcast_phase = "_broadcast_subgraph"
            self._build_broadcast_graph(broadcast_params_dict, broadcast_phase)

        return phase, True

    def _updata_param_node_default_input(self, phase, replace):
        """Pass the parameters that were replaced by a different object to the backend executor, keyed by name."""
        new_param = {x.name: replace[x] for x in replace if x is not replace[x]}
        return self._executor.updata_param_node_default_input(phase, new_param)

    def _get_shard_strategy(self, obj):
        """Return the result of the backend `get_strategy` for the current phase of `obj`."""
        real_phase = obj.phase_prefix + obj.phase + '.' + str(obj.create_time)
        return self._executor.get_strategy(real_phase)

    def _get_num_parallel_ops(self, obj):
        """Return the result of the backend `get_num_parallel_ops` for the current phase of `obj`."""
        real_phase = obj.phase_prefix + obj.phase + '.' + str(obj.create_time)
        return self._executor.get_num_parallel_ops(real_phase)

    def _get_allreduce_fusion(self, obj):
        """Return the result of the backend `get_allreduce_fusion` for the current phase of `obj`."""
        real_phase = obj.phase_prefix + obj.phase + '.' + str(obj.create_time)
        return self._executor.get_allreduce_fusion(real_phase)

    def has_compiled(self, phase='predict'):
        """
        Specify whether have been compiled.

        Args:
            phase (str): The phase name. Default: 'predict'.

        Returns:
            bool, specifies whether the specific graph has been compiled.
        """
        return self._executor.has_compiled(phase)

    def __call__(self, obj, *args, phase='predict'):
        # nothing is executed in precompile-only mode or on a parameter server
        if context.get_context("precompile_only") or _is_role_pserver():
            return None
        return self.run(obj, *args, phase=phase)

    @_wrap_func
    def _exec_pip(self, obj, *args, phase=''):
        """Execute the generated pipeline."""
        fn = obj.construct
        converted, arguments_dict, parse_method = _convert_function_arguments(fn, *args)
        if not converted:
            raise RuntimeError('Process method parameter is failure')
        args_list = tuple(arguments_dict.values())
        obj.__parse_method__ = parse_method
        return self._executor(args_list, phase)

    def run(self, obj, *args, phase='predict'):
        """
        Run the specific graph.

        Args:
            obj (Cell): The cell whose compiled graph is run.
            args (tuple): The inputs of the graph.
            phase (str): The phase name. Default: 'predict'.

        Returns:
            Tensor/Tuple, return execute result.

        Raises:
            KeyError: If no graph has been compiled for the phase.
        """
        # the 'save' graph is run without inputs and its phase name carries no prefix
        if phase == 'save':
            return self._executor((), phase + '.' + str(obj.create_time))

        phase_real = obj.phase_prefix + phase + '.' + str(obj.create_time)
        if self.has_compiled(phase_real):
            return self._exec_pip(obj, *args, phase=phase_real)
        raise KeyError('{} graph is not exist.'.format(phase_real))

    def del_net_res(self, net_id):
        """Forward to `del_net_res` of the backend executor."""
        self._executor.del_net_res(net_id)

    def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False):
        """Get graph proto from pipeline; returns None when `exec_id` has not been compiled."""
        if use_prefix:
            exec_id = obj.phase_prefix + exec_id
        if self._executor.has_compiled(exec_id) is False:
            return None
        return self._executor.get_func_graph_proto(exec_id, ir_type)

    def export(self, file_name, graph_id):
        """
        Export graph.

        Args:
            file_name (str): File name of model to export
            graph_id (str): id of graph to be exported
        """
        from .._c_expression import export_graph
        export_graph(file_name, 'AIR', graph_id)

    def fetch_info_for_quant_export(self, exec_id):
        """Fetch the quant export info of `exec_id` from the backend; returns None when it has not been compiled."""
        if self._executor.has_compiled(exec_id) is False:
            return None
        return self._executor.fetch_info_for_quant_export(exec_id)


# Module-level executor instances, created once when this module is imported.
_executor = _Executor()
_pynative_exec = _PynativeExecutor()

# ms_function is the only name exported by a star-import of this module.
__all__ = ['ms_function']
